{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Untitled1.ipynb",
      "provenance": [],
      "machine_shape": "hm",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/jalammar/jalammar.github.io/blob/master/notebooks/distilbert_masked.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "To9ENLU90WGl",
        "colab_type": "code",
        "outputId": "b0de1b40-ba6f-4904-de44-20d00394338b",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 632
        }
      },
      "source": [
        "!pip install transformers"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Collecting transformers\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/fd/f9/51824e40f0a23a49eab4fcaa45c1c797cbf9761adedd0b558dab7c958b34/transformers-2.1.1-py3-none-any.whl (311kB)\n",
            "\r\u001b[K     |█                               | 10kB 19.7MB/s eta 0:00:01\r\u001b[K     |██                              | 20kB 2.1MB/s eta 0:00:01\r\u001b[K     |███▏                            | 30kB 3.2MB/s eta 0:00:01\r\u001b[K     |████▏                           | 40kB 2.1MB/s eta 0:00:01\r\u001b[K     |█████▎                          | 51kB 2.5MB/s eta 0:00:01\r\u001b[K     |██████▎                         | 61kB 3.0MB/s eta 0:00:01\r\u001b[K     |███████▍                        | 71kB 3.5MB/s eta 0:00:01\r\u001b[K     |████████▍                       | 81kB 4.0MB/s eta 0:00:01\r\u001b[K     |█████████▌                      | 92kB 4.4MB/s eta 0:00:01\r\u001b[K     |██████████▌                     | 102kB 3.4MB/s eta 0:00:01\r\u001b[K     |███████████▋                    | 112kB 3.4MB/s eta 0:00:01\r\u001b[K     |████████████▋                   | 122kB 3.4MB/s eta 0:00:01\r\u001b[K     |█████████████▊                  | 133kB 3.4MB/s eta 0:00:01\r\u001b[K     |██████████████▊                 | 143kB 3.4MB/s eta 0:00:01\r\u001b[K     |███████████████▊                | 153kB 3.4MB/s eta 0:00:01\r\u001b[K     |████████████████▉               | 163kB 3.4MB/s eta 0:00:01\r\u001b[K     |█████████████████▉              | 174kB 3.4MB/s eta 0:00:01\r\u001b[K     |███████████████████             | 184kB 3.4MB/s eta 0:00:01\r\u001b[K     |████████████████████            | 194kB 3.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████           | 204kB 3.4MB/s eta 0:00:01\r\u001b[K     |██████████████████████          | 215kB 3.4MB/s eta 0:00:01\r\u001b[K     |███████████████████████▏        | 225kB 3.4MB/s eta 0:00:01\r\u001b[K     |████████████████████████▏       | 235kB 3.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▎      | 245kB 3.4MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▎     | 256kB 3.4MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▍    | 266kB 3.4MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▍   | 276kB 3.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████▍  | 286kB 3.4MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▌ | 296kB 3.4MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▌| 307kB 3.4MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 317kB 3.4MB/s \n",
            "\u001b[?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from transformers) (4.28.1)\n",
            "Collecting regex\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/e3/8e/cbf2295643d7265e7883326fb4654e643bfc93b3a8a8274d8010a39d8804/regex-2019.11.1-cp36-cp36m-manylinux1_x86_64.whl (643kB)\n",
            "\u001b[K     |████████████████████████████████| 645kB 32.2MB/s \n",
            "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.21.0)\n",
            "Collecting sacremoses\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/1f/8e/ed5364a06a9ba720fddd9820155cc57300d28f5f43a6fd7b7e817177e642/sacremoses-0.0.35.tar.gz (859kB)\n",
            "\u001b[K     |████████████████████████████████| 860kB 41.4MB/s \n",
            "\u001b[?25hRequirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from transformers) (1.10.7)\n",
            "Collecting sentencepiece\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/14/3d/efb655a670b98f62ec32d66954e1109f403db4d937c50d779a75b9763a29/sentencepiece-0.1.83-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)\n",
            "\u001b[K     |████████████████████████████████| 1.0MB 50.0MB/s \n",
            "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.17.3)\n",
            "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2019.9.11)\n",
            "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)\n",
            "Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.8)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.12.0)\n",
            "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.0)\n",
            "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.14.0)\n",
            "Requirement already satisfied: botocore<1.14.0,>=1.13.7 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (1.13.7)\n",
            "Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.9.4)\n",
            "Requirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.2.1)\n",
            "Requirement already satisfied: python-dateutil<3.0.0,>=2.1; python_version >= \"2.7\" in /usr/local/lib/python3.6/dist-packages (from botocore<1.14.0,>=1.13.7->boto3->transformers) (2.6.1)\n",
            "Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.14.0,>=1.13.7->boto3->transformers) (0.15.2)\n",
            "Building wheels for collected packages: sacremoses\n",
            "  Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for sacremoses: filename=sacremoses-0.0.35-cp36-none-any.whl size=883999 sha256=99b7a3619734c3954dc5827bd71ed0c5d31922622e7d8081fbafbd41a3bae471\n",
            "  Stored in directory: /root/.cache/pip/wheels/63/2a/db/63e2909042c634ef551d0d9ac825b2b0b32dede4a6d87ddc94\n",
            "Successfully built sacremoses\n",
            "Installing collected packages: regex, sacremoses, sentencepiece, transformers\n",
            "Successfully installed regex-2019.11.1 sacremoses-0.0.35 sentencepiece-0.1.83 transformers-2.1.1\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fvFvBLJV0Dkv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "from sklearn.model_selection import cross_val_score\n",
        "import torch\n",
        "import transformers as ppb\n",
        "import warnings\n",
        "warnings.filterwarnings('ignore')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "POT8e8VX0Ry1",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cyoj29J24hPX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "df = pd.read_csv('https://github.com/clairett/pytorch-sentiment-classification/raw/master/data/SST2/train.tsv', delimiter='\\t', header=None)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gTM3hOHW4hUY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "batch_1 = df[:2000]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jGvcfcCP5xpZ",
        "colab_type": "code",
        "outputId": "bb9e67d2-040f-42aa-f913-e486eb942825",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 68
        }
      },
      "source": [
        "batch_1[1].value_counts()"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "1    1041\n",
              "0     959\n",
              "Name: 1, dtype: int64"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "q1InADgf5xm2",
        "colab_type": "code",
        "outputId": "03dbeec3-0674-49ea-ff1d-181e75095d6a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 68
        }
      },
      "source": [
        "model_class, tokenizer_class, pretrained_weights = (ppb.DistilBertModel, ppb.DistilBertTokenizer, 'distilbert-base-uncased')\n",
        "\n",
        "\n",
        "# Load pretrained model/tokenizer\n",
        "tokenizer = tokenizer_class.from_pretrained(pretrained_weights)\n",
        "model = model_class.from_pretrained(pretrained_weights)"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "100%|██████████| 231508/231508 [00:00<00:00, 2645593.78B/s]\n",
            "100%|██████████| 492/492 [00:00<00:00, 163453.27B/s]\n",
            "100%|██████████| 267967963/267967963 [00:03<00:00, 71733827.75B/s]\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Dg82ndBA5xlN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "tokenized = batch_1[0].apply((lambda x: tokenizer.encode(x, add_special_tokens=True)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "URn-DWJt5xhP",
        "colab_type": "code",
        "outputId": "577219b3-9665-42b6-d137-291fe7c63a9c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "max_len = 0\n",
        "for i in tokenized.values:\n",
        "    if len(i) > max_len:\n",
        "        max_len = len(i)\n",
        "max_len"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "59"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jdi7uXo95xeq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "padded = np.array([i + [0]*(max_len-len(i)) for i in tokenized.values])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bFZMqh5w_dC0",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pimXlkWoAHYd",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5lEpOI5r5xbT",
        "colab_type": "code",
        "outputId": "2c8dd32d-34d8-4003-9352-959a30ce768e",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "np.array(padded).shape"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(2000, 59)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KVzF4rMA_m4n",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4K_iGRNa_Ozc",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "e423abd7-507c-4dbb-c085-b01ff1d94088"
      },
      "source": [
        "attention_mask = np.where(padded != 0, 1, 0)\n",
        "attention_mask.shape"
      ],
      "execution_count": 25,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(2000, 59)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 25
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "39UVjAV56PJz",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "input_ids = torch.tensor(padded)  \n",
        "attention_mask = torch.tensor(attention_mask)\n",
        "\n",
        "with torch.no_grad():\n",
        "    last_hidden_states = model(input_ids, attention_mask=attention_mask)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "C9t60At16PVs",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "features = last_hidden_states[0][:,0,:].numpy()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JD3fX2yh6PTx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "labels = batch_1[1]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ddAqbkoU6PP9",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from sklearn.model_selection import train_test_split\n",
        "train_features, test_features, train_labels, test_labels = train_test_split(features, labels)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cyEwr7yYD3Ci",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        },
        "outputId": "6bf19baa-6241-4c0a-a9c0-eb1d19fcf058"
      },
      "source": [
        "from sklearn.linear_model import LogisticRegression\n",
        "from sklearn.model_selection import GridSearchCV\n",
        "\n",
        "parameters = {'C': np.linspace(0.0001, 100, 20)}\n",
        "grid_search = GridSearchCV(LogisticRegression(), parameters)\n",
        "grid_search.fit(train_features, train_labels)\n",
        "\n",
        "print('best parameters: ', grid_search.best_params_)\n",
        "print('best scrores: ', grid_search.best_score_)"
      ],
      "execution_count": 69,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "best parameters:  {'C': 5.263252631578947}\n",
            "best scrores:  0.8286666666666667\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iCoyxRJ7ECTA",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dz_fVWx6EGI9",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kVa-s4Jj6POm",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BCQ75kyGF9Jn",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2y8LLwEeGSOL",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gG-EVWx4CzBc",
        "colab_type": "code",
        "outputId": "3d8018eb-ec5b-44b5-96f9-349d1de71671",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "lr_clf = LogisticRegression()\n",
        "lr_clf.fit(train_features, train_labels)\n",
        "lr_clf.score(test_features, test_labels)"
      ],
      "execution_count": 65,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.818"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 65
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SB3y3iu26PHM",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BjkK2NBZ4hJE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hTv5YpnN7iyx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tzQYSP-q7i7T",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8xb3_A7cEby9",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lnwgmqNG7i5l",
        "colab_type": "code",
        "outputId": "94422731-7381-4a56-8563-9e2f82cb5ec2",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "from sklearn.dummy import DummyClassifier\n",
        "clf = DummyClassifier()\n",
        "\n",
        "scores = cross_val_score(clf, train_features, train_labels)\n",
        "print(\"Dummy classifier score: %0.3f (+/- %0.2f)\" % (scores.mean(), scores.std() * 2))"
      ],
      "execution_count": 70,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Dummy classifier score: 0.528 (+/- 0.02)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YYwrotgc7i31",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JDuQ6im17iv_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GW8VoPCl0Z3q",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VDQH4khC0vCY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Cn5JzCJ14AMZ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hiPd3yok4A3i",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}